package db_gorm

import (
	"context"
	"errors"

	"gitee.com/lipore/plume/trx"

	"gorm.io/gorm"
)

// GormTransaction pairs a context.Context with a *gorm.DB handle.
// Instances are created by GormTxManager.Start, which stores the handle
// returned by Begin in conn. The embedded Context supplies the
// context.Context method set; the methods defined on *GormTransaction
// forward to conn.
type GormTransaction struct {
	context.Context
	// conn is the handle that Commit, Rollback, RollbackTo and SavePoint
	// operate on, and that LoadTx hands back to callers.
	conn *gorm.DB
}

// GormTxManager opens transactions on a *gorm.DB handle. It is returned
// as a trx.Manager by NewTxManager.
type GormTxManager struct {
	// conn is the base handle on which Start calls Begin.
	conn *gorm.DB
}

// NewTxManager returns a trx.Manager backed by a GormTxManager that
// begins its transactions on conn.
func NewTxManager(conn *gorm.DB) trx.Manager {
	return &GormTxManager{conn: conn}
}

// Start begins a new transaction on the manager's connection, bound to
// ctx, and wraps it in a GormTransaction.
//
// If Begin records an error on the returned handle (for example because
// the connection is unavailable or ctx is already cancelled), that error
// is returned and no transaction value is handed out.
func (t *GormTxManager) Start(ctx context.Context) (trx.Context, error) {
	d := t.conn.WithContext(ctx).Begin()
	// GORM reports a failed Begin through the Error field of the handle
	// it returns rather than through a second return value.
	if err := d.Error; err != nil {
		return nil, err
	}
	return &GormTransaction{
		Context: ctx,
		conn:    d,
	}, nil
}

// Commit commits the wrapped transaction and returns the error recorded
// by the commit call, or nil when none was recorded.
func (t *GormTransaction) Commit() error {
	return t.conn.Commit().Error
}

// Rollback rolls back the wrapped transaction.
//
// The method has no return value, so the handle returned by the
// underlying call is discarded and any failure it reports is not
// surfaced to the caller.
func (t *GormTransaction) Rollback() {
	_ = t.conn.Rollback()
}

// RollbackTo forwards name to RollbackTo on the wrapped transaction
// handle.
//
// The method has no return value, so the handle returned by the
// underlying call is discarded and any failure it reports is not
// surfaced to the caller.
func (t *GormTransaction) RollbackTo(name string) {
	_ = t.conn.RollbackTo(name)
}

// SavePoint forwards name to SavePoint on the wrapped transaction
// handle.
//
// The method has no return value, so the handle returned by the
// underlying call is discarded and any failure it reports is not
// surfaced to the caller.
func (t *GormTransaction) SavePoint(name string) {
	_ = t.conn.SavePoint(name)
}

// LoadTx extracts the *gorm.DB handle carried by ctx.
//
// ctx must hold a *GormTransaction; for any other dynamic type a nil
// handle and an error are returned.
func LoadTx(ctx trx.Context) (*gorm.DB, error) {
	gormTx, isGorm := ctx.(*GormTransaction)
	if !isGorm {
		return nil, errors.New("not supported transaction type")
	}
	return gormTx.conn, nil
}

// WithTx runs fn through trx.Tx, passing it the *gorm.DB handle
// extracted from the trx.Context that trx.Tx supplies.
//
// If that context does not carry a GormTransaction, the error from
// LoadTx is returned to trx.Tx and fn is not called. Otherwise the
// result of fn is returned to trx.Tx, whose own result becomes the
// result of WithTx.
//
// NOTE(review): transaction begin/commit/rollback handling presumably
// lives inside trx.Tx; confirm against that package.
func WithTx(ctx context.Context, fn func(tx *gorm.DB) error) error {
	run := func(txCtx trx.Context) error {
		db, err := LoadTx(txCtx)
		if err != nil {
			return err
		}
		return fn(db)
	}
	return trx.Tx(ctx, run)
}
